import pandas as pd
import re, sys
from pandas.core.frame import DataFrame
import numpy as np
from sqlalchemy import insert
from function import numfeatures,ALL_features,AB_features,pl_iat,columns,numerical,AB_num
from sklearn.model_selection import train_test_split

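# Preprocess the data set; takes two command-line arguments (sys.argv[1], sys.argv[2]).
# Input: the original, labelled data set.
# Output: the preprocessed data set, keeping raw information such as time and IP.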

# Names of the basic per-flow columns (direction, timing, endpoints, packet
# size, IAT, IP/TCP fields and the ds* payload/IAT statistics).
basic_features = [
    '%dir', 'timeFirst', 'timeLast', 'duration',
    'srcIP', 'srcIPOrg', 'srcPort',
    'dstIP', 'dstIPOrg', 'dstPort',
    'maxPktSz', 'maxIAT', 'aveIAT', 'stdIAT', 'pktps',
    'ipMindIPID', 'ipMaxdIPID', 'ipMinTTL', 'ipMaxTTL',
    'tcpFlwLssAckRcvdBytes', 'tcpInitWinSz', 'tcpMinWinSz', 'tcpMaxWinSz',
    'tcpWinSzChgDirCnt', 'tcpMSS', 'tcpTmS', 'tcpTmER', 'tcpBtm', 'connF',
    'dsMaxPl', 'dsMedianPl', 'dsRangePl',
    'dsMaxIat', 'dsMeanIat', 'dsMedianIat',
]


# ------------ Preprocessing -------------

def choose_portion_features(data, features):
    """Select the given feature columns plus the trailing label columns.

    The label columns are found by walking ``data.columns`` backwards from
    the last column and collecting every name matched by the regex
    ``'class*'`` (via ``re.match``) until the first name that does not match.

    Args:
        data: DataFrame whose label column(s) are its last column(s).
        features: Names of the feature columns to keep. The list is not
            modified.

    Returns:
        ``data`` restricted to ``features`` followed by the label columns,
        the latter in their original left-to-right order.
    """
    label_columns = []
    # Scan from the right; iterating over the columns (instead of using an
    # unbounded negative index) also copes with frames in which every column
    # is a label column.
    # NOTE(review): 'class*' means "clas" followed by zero or more "s", so any
    # name starting with "clas" is treated as a label column.
    for name in reversed(list(data.columns)):
        if not re.match('class*', str(name)):
            break
        label_columns.insert(0, name)
    # Build a fresh list so the caller's ``features`` list is left untouched.
    selected = [*features, *label_columns]
    return data[selected]
def remove_duration_less_than(data, duration: float):
    """Keep only the flows whose duration is strictly greater than ``duration``.

    Args:
        data: DataFrame with a 'duration' column whose values convert to float.
        duration: Threshold; rows at or below it are dropped.

    Returns:
        The rows of ``data`` with ``float(duration column) > duration``.
    """
    long_enough = data['duration'].map(float).gt(duration)
    return data[long_enough]
def remove_zero_payload(data):
    """Keep only the flows with a positive maximum packet size in both directions.

    A row is kept when both 'A_maxPktSz' and 'B_maxPktSz', converted to
    float, are greater than zero.
    """
    has_payload_a = data['A_maxPktSz'].map(float).gt(0)
    has_payload_b = data['B_maxPktSz'].map(float).gt(0)
    return data[has_payload_a & has_payload_b]
def remove_zero_IAT(data):
    """Drop the flows whose maximum IAT is zero in either direction.

    A row is dropped when 'A_maxIAT' or 'B_maxIAT', converted to float,
    equals zero; all other rows are kept.
    """
    zero_in_a = data['A_maxIAT'].map(float) == 0
    zero_in_b = data['B_maxIAT'].map(float) == 0
    return data[~(zero_in_a | zero_in_b)]
def remove_features(data, remove_list):
    """Return ``data`` without the columns named in ``remove_list``.

    The remaining columns keep their original order; names in
    ``remove_list`` that are not columns of ``data`` are ignored.
    """
    kept = [name for name in data.columns if name not in remove_list]
    return data[kept]
    
def remove_unique_features(data):
    """Drop the feature columns that carry no information.

    A column is dropped when all of its entries are missing, or when it
    holds exactly one distinct value (NaN counted as a value). The column
    named 'class' is never dropped.

    Args:
        data: DataFrame to filter.

    Returns:
        A new DataFrame without the constant / all-missing columns.
    """
    def is_unique_value(col):
        # nunique() is hash-based, so unlike np.unique() it does not need to
        # sort the values and therefore also works on object columns that mix
        # strings with NaN.
        return col.isna().all() or col.nunique(dropna=False) == 1

    drop_list = [
        col_name
        for col_name in data
        if col_name != 'class' and is_unique_value(data[col_name])
    ]
    return data.drop(drop_list, axis=1)
def choose_length_features(data):
    """Select the length-related features plus the 'class' column.

    A name from the module-level ``numerical`` list counts as a length
    feature when the regex 'ds.+Pl|PktSz|PktSize' is found anywhere in it.
    """
    matcher = re.compile('ds.+Pl|PktSz|PktSize')
    selected = list(filter(matcher.search, numerical))
    selected.append('class')
    return data[selected]
def choose_time_features(data):
    """Select the time-related features plus the 'class' column.

    A name from the module-level ``numerical`` list counts as a time
    feature when the regex 'TTL|Tm|tm$|ds.+Iat|duration' is found in it.

    Args:
        data: DataFrame containing the selected columns and 'class'.

    Returns:
        ``data`` restricted to the time features followed by 'class'.
    """
    pattern = 'TTL|Tm|tm$|ds.+Iat|duration'
    # NOTE(review): '%dir' is not part of the selection; add it explicitly
    # here if the flow direction is wanted alongside the time features.
    time_features = [feature for feature in numerical if re.search(pattern, feature)]
    time_features.append('class')
    return data[time_features]

def gen_train_test(data, ratio, label_name='class', random_state=None):
    """Split ``data`` into stratified training and test sets.

    Every column except ``label_name`` and '%dir' is used as a feature.

    Args:
        data: DataFrame holding the features and the label column.
        ratio: Passed to ``train_test_split`` as ``test_size``.
        label_name: Name of the label column.
        random_state: Optional seed forwarded to ``train_test_split`` so a
            split can be reproduced; ``None`` keeps the split random.

    Returns:
        The result of ``train_test_split``: X_train, X_test, y_train, y_test.
    """
    excluded = (r'%dir',)
    feature_cols = [
        col for col in data.columns
        if col != label_name and col not in excluded
    ]
    X = data[feature_cols].values
    # Double brackets keep y two-dimensional (one column).
    y = data[[label_name]].values
    return train_test_split(
        X, y, test_size=ratio, stratify=y, random_state=random_state
    )
def ndarray2csv(data, file):
    """Write ``data`` to ``file`` as CSV, without index column or header row.

    ``data`` is wrapped in a DataFrame first; ``file`` is handed to
    ``DataFrame.to_csv`` unchanged.
    """
    frame = DataFrame(data)
    # header=0 is falsy, so no header line is written.
    frame.to_csv(file, header=0, index=False)
#---------------------------------------
if __name__ == '__main__':
    # Usage: <script> <labelled_input_csv> <preprocessed_output_csv>
    if len(sys.argv) < 3:
        sys.exit('usage: {} <input_csv> <output_csv>'.format(sys.argv[0]))
    input_path, output_path = sys.argv[1], sys.argv[2]

    # Load the data.
    data = pd.read_csv(input_path, low_memory=False, delimiter=',')

    # Row filters: positive duration, non-zero max IAT and positive max
    # packet size in both directions.
    data = remove_duration_less_than(data, 0)
    data = remove_zero_IAT(data)
    data = remove_zero_payload(data)

    # Combine the two imported feature lists into a new list instead of
    # extending AB_num in place, so the imported lists stay unchanged.
    selected_features = [*AB_num, *pl_iat]
    data = choose_portion_features(data, selected_features)

    data.to_csv(output_path, index=False, sep=',')
